import json
import os
import pickle
from typing import List, Dict, Set

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer

class RAG_INFERENCE():
    """Evaluate FAISS retrieval against per-question knowledge-point labels.

    Loads an embedding model, a FAISS index, the pickled document list whose
    positions the index ids refer to, and a document x knowledge-point matrix.
    For each question it retrieves the top-k documents and measures how many
    of the question's required knowledge points those documents cover.
    """

    def __init__(self, embedding_model_path, faiss_index_path, docs_path, out_file_path, knowledge_matrix_path):
        """Load all retrieval artefacts.

        Args:
            embedding_model_path: name or path passed to ``SentenceTransformer``.
            faiss_index_path: file readable by ``faiss.read_index``.
            docs_path: pickle file holding the documents; FAISS result ids are
                used as positions into this sequence.
            out_file_path: base path; the per-question metrics JSON is written
                beside it with a ``_coverage_metrics.json`` suffix.
            knowledge_matrix_path: ``.npy`` file; row i = document i, column
                j = knowledge point j, non-zero = document covers that point.
        """
        self.docs_path = docs_path
        self.embed_model = SentenceTransformer(embedding_model_path)

        self.index = faiss.read_index(faiss_index_path)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(docs_path, 'rb') as f:
            self.documents = pickle.load(f)
        self.out_file_path = out_file_path

        self.knowledge_matrix = np.load(knowledge_matrix_path)  # shape: [n_docs, m_dim]

    def search_indices(self, query: str, top_k: int = 10) -> List[int]:
        """Return the positions (into ``self.documents``) of the top_k hits, best first."""
        # FAISS expects a contiguous float32 (n_queries, dim) array.
        query_vec = np.asarray(self.embed_model.encode([query]), dtype=np.float32)
        _, ids = self.index.search(query_vec, top_k)
        # FAISS pads with -1 when fewer than top_k vectors exist; drop those
        # instead of letting documents[-1] silently alias the last document.
        return [int(i) for i in ids[0] if i >= 0]

    def search_docs(self, query: str, top_k: int = 10) -> List[str]:
        """Return the top_k retrieved documents for *query*, best first."""
        return [self.documents[i] for i in self.search_indices(query, top_k=top_k)]

    @staticmethod
    def _covered_knowledge_points(knowledge_matrix: np.ndarray, doc_indices: List[int]) -> Set[int]:
        """Return the knowledge-point ids covered by at least one of *doc_indices*."""
        if len(doc_indices) == 0:
            return set()
        union = np.any(knowledge_matrix[doc_indices], axis=0)
        return {int(kp) for kp in np.flatnonzero(union)}

    @staticmethod
    def _coverage_out_path(out_file_path: str) -> str:
        """Derive the metrics JSON path from the base output path.

        ``out.txt`` -> ``out_coverage_metrics.json``. Only the final extension
        is replaced, and a path without ``.txt`` no longer maps onto itself
        (which would overwrite the base file).
        """
        root, _ext = os.path.splitext(out_file_path)
        return root + "_coverage_metrics.json"

    def calculate_mrr(self, true_indexes: Set[int], retrieved_indices: List[int], knowledge_matrix: np.ndarray) -> float:
        """Reciprocal rank of the first retrieved document covering a required point.

        This is the per-question term of MRR (Mean Reciprocal Rank); the mean is
        taken across questions in ``evaluate_coverage``. Returns 0.0 when nothing
        was retrieved, nothing is required, or no retrieved document helps.
        """
        if not retrieved_indices or not true_indexes:
            return 0.0
        for rank, doc_idx in enumerate(retrieved_indices, start=1):
            if self._covered_knowledge_points(knowledge_matrix, [doc_idx]) & true_indexes:
                return 1.0 / rank
        return 0.0

    def calculate_recall_at_k(self, true_indexes: Set[int], retrieved_indices: List[int], knowledge_matrix: np.ndarray) -> float:
        """Fraction of required knowledge points covered by all retrieved documents.

        "k" is however many documents were retrieved (the ``top_k`` used for
        the search), not a fixed 10.
        """
        if not retrieved_indices or not true_indexes:
            return 0.0
        covered = self._covered_knowledge_points(knowledge_matrix, retrieved_indices)
        return len(covered & true_indexes) / len(true_indexes)

    def _score_question(self, idx: int, item: Dict, top_k: int):
        """Retrieve for one question record.

        Returns ``(coverage_rate, mrr, recall, detail_record)``, with the raw
        floats kept unrounded for averaging.

        Raises:
            ValueError: the record has no question text under ``'result'``.
            KeyError: the record has no ``'indexs'`` list.
        """
        question_text = item.get('result')
        if question_text is None:
            # Previously this silently searched for the literal string "None".
            raise ValueError("record has no 'result' (question text) field")
        true_indexes = set(item['indexs'])

        retrieved_indices = self.search_indices(f"{question_text}", top_k=top_k)
        covered_kp_ids = self._covered_knowledge_points(self.knowledge_matrix, retrieved_indices)

        intersection_count = len(covered_kp_ids & true_indexes)
        total_kp_count = len(true_indexes)
        coverage_rate = intersection_count / total_kp_count if total_kp_count > 0 else 0.0
        mrr = self.calculate_mrr(true_indexes, retrieved_indices, self.knowledge_matrix)
        recall = self.calculate_recall_at_k(true_indexes, retrieved_indices, self.knowledge_matrix)

        record = {
            "idx": idx,
            "coverage_rate": round(coverage_rate, 4),
            "mrr": round(mrr, 4),
            # Key kept as "recall@10" for output compatibility; the value is
            # recall over all top_k retrieved documents.
            "recall@10": round(recall, 4),
            "retrieved_doc_ids": list(retrieved_indices),
            "expected_knowledge_indexes": [int(x) for x in sorted(true_indexes)],
            "retrieved_coverage_indexes": sorted(covered_kp_ids),
            "covered_count": intersection_count,
            "total_required": total_kp_count,
        }
        return coverage_rate, mrr, recall, record

    def evaluate_coverage(self, questions_path: str, top_k: int = 10) -> Dict:
        """Score every question in *questions_path* and write per-question details.

        Each question record is expected to carry the question text under
        ``'result'`` and the required knowledge-point ids under ``'indexs'``.
        Records that fail are reported and excluded from the averages.

        Returns:
            Dict with ``avg_coverage_rate``, ``avg_mrr`` and ``avg_recall_at_10``.
            The last key name is kept for compatibility; it is recall at ``top_k``.
        """
        with open(questions_path, "r", encoding="utf-8") as f:
            question_data = json.load(f)

        detailed_results = []
        coverage_rates, mrr_scores, recall_scores = [], [], []
        skipped = 0

        for i, item in enumerate(question_data):
            try:
                coverage_rate, mrr, recall, record = self._score_question(i + 1, item, top_k)
            except Exception as e:
                # Best-effort: one bad record must not abort the run, but report
                # which one failed since it is dropped from the averages.
                skipped += 1
                print(f"Question {i + 1} skipped ({type(e).__name__}): {e}")
                continue

            coverage_rates.append(coverage_rate)
            mrr_scores.append(mrr)
            recall_scores.append(recall)
            detailed_results.append(record)

        avg_coverage = sum(coverage_rates) / len(coverage_rates) if coverage_rates else 0.0
        avg_mrr = sum(mrr_scores) / len(mrr_scores) if mrr_scores else 0.0
        avg_recall = sum(recall_scores) / len(recall_scores) if recall_scores else 0.0
        print(f"\nCoverage Rate：{avg_coverage:.2%}")
        print(f"AVG MRR：{avg_mrr:.4f}")
        print(f"AVG Recall@{top_k}：{avg_recall:.2%}")
        if skipped:
            print(f"⚠️ {skipped} question(s) skipped; averages exclude them.")

        coverage_out_path = self._coverage_out_path(self.out_file_path)
        with open(coverage_out_path, "w", encoding="utf-8") as f:
            json.dump(detailed_results, f, indent=2, ensure_ascii=False)
        print(f"✅ Saved: {coverage_out_path}")

        return {
            "avg_coverage_rate": avg_coverage,
            "avg_mrr": avg_mrr,
            "avg_recall_at_10": avg_recall
        }

if __name__ == "__main__":
    # Script entry: run the coverage evaluation over the bundled question set.
    # All paths are resolved relative to the current working directory.
    evaluator_config = dict(
        embedding_model_path="embedding_model_EN",
        faiss_index_path="faiss_index/doc.index",
        docs_path="faiss_index/docs.pkl",
        out_file_path="out.txt",
        knowledge_matrix_path="kp_matrix.npy",
    )
    evaluator = RAG_INFERENCE(**evaluator_config)
    summary = evaluator.evaluate_coverage("questions_Choices.json", top_k=50)
    print(str(summary))

